Conversation
```python
from pyro.optim.clipped_adam import ClippedAdam as _ClippedAdam

import funsor
from funsor.adam import Adam  # noqa: F401
```

For compatibility with pyroapi.
```python
value, _ = PARAM_STORE[name]
if event_dim is None:
    event_dim = value.dim()
output = funsor.Reals[value.shape[value.dim() - event_dim :]]
```

Infer the output domain when pyro.param was already defined elsewhere.
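The slicing in the hunk above takes the rightmost `event_dim` sizes of the stored value's shape. A plain-Python sketch of that logic, under the assumption that we model `value.shape` as a tuple rather than a torch.Size (the name `infer_event_shape` is illustrative, not part of the PR):

```python
# Sketch of the event-shape slicing above, using a plain tuple in place of
# a tensor shape. infer_event_shape is a hypothetical helper name.
def infer_event_shape(shape, event_dim=None):
    """Return the rightmost event_dim sizes of shape; default to the whole shape."""
    ndim = len(shape)
    if event_dim is None:
        event_dim = ndim  # mirrors `event_dim = value.dim()` above
    return shape[ndim - event_dim:]

print(infer_event_shape((2, 3, 4), event_dim=1))  # (4,)
print(infer_event_shape((2, 3, 4)))               # (2, 3, 4)
```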
```python
def step(self, *args, **kwargs):
    self.optim.num_steps = 1
    return self.run(*args, **kwargs)
```

For compatibility with the SVI interface.
Hmm, let's think about alternative workarounds... One issue here is that the Adam optimizer statistics would not persist across SVI steps.

One option is simply to change pyroapi's SVI interface to look for `.run()` and, if it is missing, fall back to `.step()`. Also, I think it's more important to create a simple didactic example than to fastidiously conform to the pyroapi interface (since that interface hasn't seen much use).
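The fallback suggested above could be sketched as follows. This is a hypothetical illustration of the dispatch idea only; `svi_step`, `OnlyStep`, and `HasRun` are invented names, not pyroapi API:

```python
# Hypothetical sketch: prefer .run() when the SVI object provides it,
# otherwise fall back to .step().
def svi_step(svi, *args, **kwargs):
    method = getattr(svi, "run", None) or svi.step
    return method(*args, **kwargs)

class OnlyStep:           # stand-in for an SVI exposing only .step()
    def step(self, x):
        return x + 1

class HasRun:             # stand-in for an SVI exposing .run()
    def run(self, x):
        return x * 2

print(svi_step(OnlyStep(), 1))  # 2
print(svi_step(HasRun(), 3))    # 6
```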
```python
for p in params:
    p.grad = torch.zeros_like(p.grad)
return loss.item()
```

```python
with funsor.terms.lazy:
```

The lazy interpretation is needed here to make sure that funsor.Integrate is not eagerly expanded inside Expectation.
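The eager/lazy distinction above can be illustrated with a toy interpreter switch. This is NOT funsor's actual machinery, just a sketch of the idea: under the default eager mode an operation evaluates immediately, while inside a `lazy()` context it builds a symbolic term that can be simplified before evaluation.

```python
import contextlib

# Toy illustration of switching interpretations (not funsor's API).
_LAZY = False

@contextlib.contextmanager
def lazy():
    """Temporarily switch the toy interpreter into lazy mode."""
    global _LAZY
    saved, _LAZY = _LAZY, True
    try:
        yield
    finally:
        _LAZY = saved

def add(x, y):
    if _LAZY:
        return ("add", x, y)  # build a term instead of computing a value
    return x + y

print(add(1, 2))        # eager: 3
with lazy():
    print(add(1, 2))    # lazy: ('add', 1, 2)
```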
I think you're right, but let's discuss. That's a little different from Pyro, where jit is baked into ELBO subclasses.
Addresses #533
Group-coded with @fritzo @eb8680 @fehiepsi